import json
from typing import List, Optional

from fastapi import Body
from sse_starlette import EventSourceResponse

from chatchat.configs import (
    DEFAULT_EMBEDDING_MODEL,
    DEFAULT_VS_TYPE,
    OVERLAP_SIZE,
    log_verbose,
    logger,
)
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
from chatchat.server.knowledge_base.kb_summary.base import KBSummaryService
from chatchat.server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from chatchat.server.knowledge_base.utils import list_files_from_folder
from chatchat.server.utils import BaseResponse, get_ChatOpenAI, wrap_done


def recreate_summary_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    allow_empty_kb: bool = Body(True),
    vs_type: str = Body(DEFAULT_VS_TYPE),
    embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
    file_description: str = Body(""),
    model_name: str = Body(None, description="LLM 模型名称。"),
    temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        None, description="限制LLM生成Token数量，默认None代表模型最大值"
    ),
):
    """
    Rebuild the file summaries of a single knowledge base.

    The summary store of the knowledge base is dropped and created again, then
    every file returned by ``list_files_from_folder`` is summarized and added to
    it. Progress is streamed to the client as server-sent events; every event
    payload is a JSON string carrying at least ``code`` and ``msg``.

    :param knowledge_base_name: name of the knowledge base to summarize.
    :param allow_empty_kb: when false, a knowledge base whose ``exists()`` check
        fails is reported with a 404 event instead of being processed.
    :param vs_type: vector store type handed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name used for the knowledge base and
        its summary store.
    :param file_description: forwarded unchanged to ``summary.summarize``.
    :param model_name: LLM model name.
    :param temperature: LLM sampling temperature.
    :param max_tokens: cap on generated tokens; ``None`` is forwarded as-is.
    :return: an ``EventSourceResponse`` wrapping the progress generator.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)

        # Error events must be JSON strings like every other event: a raw dict
        # is expanded by sse-starlette into ServerSentEvent keyword arguments,
        # and "code"/"msg" are not valid ones.
        if not kb.exists() and not allow_empty_kb:
            yield json.dumps(
                {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                ensure_ascii=False,
            )
            return

        error_msg = "could not recreate summary vector store because failed to access embed model."
        if not kb.check_embed_model(error_msg):
            yield json.dumps({"code": 404, "msg": error_msg}, ensure_ascii=False)
            return

        # Start from a clean summary store for this knowledge base.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            local_wrap=True,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            local_wrap=True,
        )
        # Text summarization adapter.
        summary = SummaryAdapter.form_summary(
            llm=llm, reduce_llm=reduce_llm, overlap_size=OVERLAP_SIZE
        )
        files = list_files_from_folder(knowledge_base_name)
        total = len(files)

        # 1-based counter so it can be reported directly as "finished".
        for finished, file_name in enumerate(files, start=1):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(
                file_description=file_description, docs=doc_infos
            )

            status_kb_summary = kb_summary.add_kb_summary(
                summary_combine_docs=docs
            )
            if status_kb_summary:
                logger.info(f"({finished} / {total}): {file_name} 总结完成")
                yield json.dumps(
                    {
                        "code": 200,
                        "msg": f"({finished} / {total}): {file_name}",
                        "total": total,
                        "finished": finished,
                        "doc": file_name,
                    },
                    ensure_ascii=False,
                )
            else:
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                # ensure_ascii=False keeps the message readable, matching the
                # success event above.
                yield json.dumps(
                    {
                        "code": 500,
                        "msg": msg,
                    },
                    ensure_ascii=False,
                )

    return EventSourceResponse(output())


def summary_file_to_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    file_name: str = Body(..., examples=["test.pdf"]),
    allow_empty_kb: bool = Body(True),
    vs_type: str = Body(DEFAULT_VS_TYPE),
    embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
    file_description: str = Body(""),
    model_name: str = Body(None, description="LLM 模型名称。"),
    temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        None, description="限制LLM生成Token数量，默认None代表模型最大值"
    ),
):
    """
    Summarize one file of a knowledge base, selected by file name.

    The documents listed for ``file_name`` are summarized and the result is
    added to the knowledge base's summary store. The outcome is streamed as a
    single server-sent event whose payload is a JSON string with ``code`` and
    ``msg`` (plus ``doc`` on success).

    :param knowledge_base_name: name of the knowledge base that owns the file.
    :param file_name: name of the file whose documents are summarized.
    :param allow_empty_kb: when false, a knowledge base whose ``exists()`` check
        fails is reported with a 404 event instead of being processed.
    :param vs_type: vector store type handed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name used for the knowledge base and
        its summary store.
    :param file_description: forwarded unchanged to ``summary.summarize``.
    :param model_name: LLM model name.
    :param temperature: LLM sampling temperature.
    :param max_tokens: cap on generated tokens; ``None`` is forwarded as-is.
    :return: an ``EventSourceResponse`` wrapping the result generator.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)

        # The error event must be a JSON string like the other events: a raw
        # dict is expanded by sse-starlette into ServerSentEvent keyword
        # arguments, and "code"/"msg" are not valid ones.
        if not kb.exists() and not allow_empty_kb:
            yield json.dumps(
                {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                ensure_ascii=False,
            )
            return

        # Unlike the full rebuild, the existing summary store is not dropped
        # here; create_kb_summary is only called on it.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            local_wrap=True,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            local_wrap=True,
        )
        # Text summarization adapter.
        summary = SummaryAdapter.form_summary(
            llm=llm, reduce_llm=reduce_llm, overlap_size=OVERLAP_SIZE
        )

        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(file_description=file_description, docs=doc_infos)

        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            logger.info(f" {file_name} 总结完成")
            yield json.dumps(
                {
                    "code": 200,
                    "msg": f"{file_name} 总结完成",
                    "doc": file_name,
                },
                ensure_ascii=False,
            )
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            # ensure_ascii=False keeps the message readable, matching the
            # success event above.
            yield json.dumps(
                {
                    "code": 500,
                    "msg": msg,
                },
                ensure_ascii=False,
            )

    return EventSourceResponse(output())


def summary_doc_ids_to_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    doc_ids: List = Body([], examples=[["uuid"]]),
    vs_type: str = Body(DEFAULT_VS_TYPE),
    embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
    file_description: str = Body(""),
    model_name: str = Body(None, description="LLM 模型名称。"),
    temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        None, description="限制LLM生成Token数量，默认None代表模型最大值"
    ),
) -> BaseResponse:
    """
    Summarize selected documents of a knowledge base, chosen by document id.

    The documents are looked up through ``kb.get_doc_by_ids``, summarized, and
    returned in the response body; nothing is written to a summary store here.

    :param knowledge_base_name: name of the knowledge base to read from.
    :param doc_ids: ids of the documents to summarize.
    :param vs_type: vector store type handed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name handed to the knowledge base service.
    :param file_description: forwarded unchanged to the summarizer.
    :param model_name: LLM model name.
    :param temperature: LLM sampling temperature.
    :param max_tokens: cap on generated tokens; ``None`` is forwarded as-is.
    :return: a 404 ``BaseResponse`` with empty data when the knowledge base does
        not exist, otherwise a 200 ``BaseResponse`` whose ``data["summarize"]``
        lists the summarized documents as dicts.
    """
    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
    if not kb.exists():
        return BaseResponse(
            code=404, msg=f"未找到知识库 {knowledge_base_name}", data={}
        )

    # Both the map-stage and reduce-stage models are built from the same settings,
    # but as two separate instances.
    llm_settings = {
        "model_name": model_name,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "local_wrap": True,
    }
    llm = get_ChatOpenAI(**llm_settings)
    reduce_llm = get_ChatOpenAI(**llm_settings)

    # Text summarization adapter.
    summarizer = SummaryAdapter.form_summary(
        llm=llm, reduce_llm=reduce_llm, overlap_size=OVERLAP_SIZE
    )

    fetched_docs = kb.get_doc_by_ids(ids=doc_ids)

    # Wrap each fetched document together with the id it was requested by.
    # Pairing is positional; zip stops at the shorter of the two sequences.
    wrapped_docs = []
    for vs_id, fetched in zip(doc_ids, fetched_docs):
        wrapped_docs.append(DocumentWithVSId(**fetched.dict(), id=vs_id))

    summarized = summarizer.summarize(
        file_description=file_description, docs=wrapped_docs
    )

    # Convert the summarized documents to plain dicts for the response payload.
    payload = [dict(item.dict()) for item in summarized]

    return BaseResponse(code=200, msg="总结完成", data={"summarize": payload})
